diff --git a/megatron/arguments.py b/megatron/arguments.py
index 102e890..c3504bd 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@ import os
 
 import torch
 
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -43,6 +46,9 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
 
+    # FastMoE arguments.
+    parser = _add_fmoe_args(parser)
+
     # Custom arguments.
     if extra_args_provider is not None:
         parser = extra_args_provider(parser)
@@ -316,6 +322,12 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
 
+    # If fmoe_num_experts is not specified,
+    # we are using an older version of Megatron,
+    # so copy num_experts to fmoe_num_experts.
+    if not hasattr(args, 'fmoe_num_experts'):
+        args.fmoe_num_experts = args.num_experts
+
     _print_args(args)
     return args
 
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index ceba352..01754d0 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -124,6 +124,10 @@ def read_metadata(tracker_filename):
                 sys.exit()
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)
+    
+    args = get_args()
+    if args.fmoefy:
+        return iteration, release
 
     # Get the max iteration retrieved across the ranks.
     iters_cuda = torch.cuda.LongTensor([iteration])
@@ -134,6 +138,7 @@ def read_metadata(tracker_filename):
     # If not, print a warning and chose the maximum
     # iteration across all ranks.
     if iteration != max_iter:
+        rank = torch.distributed.get_rank()
         print('WARNING: on rank {} found iteration {} in the '
               'metadata while max iteration across the ranks '
               'is {}, replacing it with max iteration.'.format(
@@ -399,7 +404,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                     opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                 else:
                     opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
-        except KeyError:
+        except KeyError as e:
+            print(e)
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
                          'attempting to load the optimizer state, '
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 2f6e1b8..e2483db 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
+    6: np.float32,
     7: np.double,
     8: np.uint16
 }
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
-        np.float: 4,
+        np.float32: 4,
         np.double: 8
     }
 
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index d8bee27..6f4ecfb 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -101,8 +101,9 @@ def get_megatron_optimizer(model,
 
     # Determine whether the params have main-grad field.
     params_have_main_grad = False
-    if args.DDP_impl == 'local':
-        params_have_main_grad = True
+    # FastMoE does not have main_grad field
+    # if args.DDP_impl == 'local':
+    #     params_have_main_grad = True
 
     if args.fp16 or args.bf16:
 
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 36cd915..7b4eaaa 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,6 +54,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     #   - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    # FastMoE
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)
@@ -65,7 +67,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            # FastMoE: grads of params tagged dp_comm == 'none' are kept apart.
+            if hasattr(param, 'dp_comm') and param.dp_comm == 'none':
+                grads_in_moe.append(grad)
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -74,6 +80,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
+        # FastMoE TODO
+        assert False, f"norm_type {norm_type} is not supported by FastMoE "
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -98,7 +106,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
 
+            # FastMoE
+            if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+                grad_norm, _ = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads_in_moe],
+                    False # no per-parameter norm
+                )
+                grad_norm = grad_norm ** norm_type
+                torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+                total_norm += grad_norm
         else:
+            # FastMoE TODO
+            assert False, f"norm_type {norm_type} is not supported by FastMoE "
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index d6ac42e..7eecff4 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -257,6 +257,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                                                                   param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
+                        # FastMoE
+                        if hasattr(param, 'dp_comm'):
+                            main_param.dp_comm = param.dp_comm
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
                         fp32_from_float16_params_this_group.append(main_param)
@@ -411,18 +414,27 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             # We are done with scaling gradients
             # so we can update the loss scale.
             self.grad_scaler.update(found_inf_flag)
-
+            
+            # FastMoE: this early return now happens after gradient clipping.
             # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
+            # if found_inf_flag:
+            #    return False, None, None
 
         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
+
+        # FastMoE: always call clip_grad_norm (the former guard
+        # `if self.clip_grad > 0.0:` is removed) to avoid dead-lock.
+        # NOTE(review): confirm clip_grad == 0.0 is still handled safely.
+        grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
 
+        # move early return to here to avoid dead-lock in FastMoE 
+        # If we found inf/nan, skip the update.
+        if found_inf_flag:
+            return False, None, None
+
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
diff --git a/megatron/schedules.py b/megatron/schedules.py
index ac5ba6f..26b717a 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -24,7 +24,10 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import p2p_communication
 from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model import ModelType
 
@@ -66,7 +69,7 @@ def deallocate_output_tensor(out):
         dtype = out.dtype,
     )
         
-def custom_backward(output, grad_output):
+def custom_backward(output, grad_output, bal_loss):
     '''Directly call C++ autograd engine.
 
     To make the 'deallocate_output_tensor' (above) optimization work, the C++
@@ -89,11 +92,16 @@ def custom_backward(output, grad_output):
             output,
             memory_format = torch.preserve_format,
         )
+        tensors = (output,)
+        grad_tensors = (grad_output,)
+    else:
+        tensors = (output, bal_loss)
+        grad_tensors = (grad_output, None)
 
     # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
     Variable._execution_engine.run_backward(
-        tensors = (output,),
-        grad_tensors = (grad_output,),
+        tensors = tensors,
+        grad_tensors = grad_tensors,
         keep_graph = False,
         create_graph = False,
         inputs = tuple(),
@@ -127,7 +135,8 @@ def forward_step(forward_step_func,
         unwrap_output_tensor = True
 
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(data_iterator, model)
+    output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+    bal_loss = bal_loss / get_num_microbatches()
     if mpu.is_pipeline_last_stage():
         if not collect_non_loss_data:
             output_tensor = loss_func(output_tensor)
@@ -145,13 +154,14 @@ def forward_step(forward_step_func,
     # downstream as well.
     if mpu.is_pipeline_stage_after_split() and \
             args.model_type == ModelType.encoder_and_decoder:
+        assert False, f"encoder-decoder model is not supported by FastMoE "
         return [output_tensor, input_tensor[-1]]
     if unwrap_output_tensor:
-        return output_tensor
-    return [output_tensor]
+        return output_tensor, bal_loss
+    return [output_tensor], bal_loss  # keep the list wrapper callers index with [0]
 
 
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
     """Backward step through passed-in output tensor.
 
     If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -185,7 +195,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    custom_backward(output_tensor[0], output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0], bal_loss)
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
@@ -241,20 +251,20 @@ def forward_backward_no_pipelining(forward_step_func,
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator,
+            output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
                                          model, input_tensor, forward_data_store,
                                          collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator,
+    output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
                                  model, input_tensor, forward_data_store,
                                  collect_non_loss_data)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
 
     return forward_data_store
 
@@ -269,6 +279,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
+    # FastMoE TODO
+    assert False, "FastMoE not supports pipeline with interleaving"
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     forward_data_store = []
@@ -646,15 +658,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
     # Input, output tensors only need to be saved when doing backward passes
     input_tensors = None
     output_tensors = None
+    bal_losses = None
     if not forward_only:
         input_tensors = []
         output_tensors = []
+        bal_losses = []
     forward_data_store = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
                                      collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -662,6 +676,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            bal_losses.append(bal_loss)
             deallocate_output_tensor(output_tensor[0])
 
     # Before running 1F1B, need to receive first forward tensor.
@@ -674,7 +689,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
     for i in range(num_microbatches_remaining):
         last_iteration = (i == (num_microbatches_remaining - 1))
 
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
                                      collect_non_loss_data)
         if forward_only:
@@ -692,16 +707,18 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            bal_losses.append(bal_loss)
             deallocate_output_tensor(output_tensor[0])
 
             # Pop input_tensor and output_tensor from the start of the list for
             # the backward pass.
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
+            bal_loss = bal_losses.pop(0)  # oldest queued balance loss
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
             if last_iteration:
                 input_tensor = None
@@ -716,12 +733,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         for i in range(num_warmup_microbatches):
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
+            bal_loss = bal_losses.pop(0)
 
             output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
 
diff --git a/megatron/training.py b/megatron/training.py
index 023bdf1..caefb88 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -36,8 +36,13 @@ from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
 from megatron.model import Float16Module
 from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
@@ -45,7 +50,11 @@ from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.initialize import set_jit_fusion_options
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -119,6 +128,13 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
+    # Initialize FastMoE
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v3.0.2")
+        model_provider = patch_model_provider(model_provider, Megatron_Version='v3.0.2')
+
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
@@ -466,10 +482,12 @@ def train_step(forward_step_func, data_iterator,
 
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
+            grad = word_embeddings_weight.grad
+            # FastMoE does not have main_grad field
+            # if args.DDP_impl == 'local':
+            #     grad = word_embeddings_weight.main_grad
+            # else:
+            #     grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
 
     # All-reduce position_embeddings grad across first (encoder) and split (decoder) 
@@ -568,26 +586,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     # Logging.
     timers_to_log = []
 
-    def add_to_logging(name):
-        if name in timers.timers:
-            timers_to_log.append(name)
-    add_to_logging('forward-compute')
-    add_to_logging('forward-recv')
-    add_to_logging('forward-send')
-    add_to_logging('forward-backward-send-forward-backward-recv')
-    add_to_logging('backward-compute')
-    add_to_logging('backward-recv')
-    add_to_logging('backward-send')
-    add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-send-backward-recv')
-    add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('optimizer-copy-to-main-grad')
-    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-main-grad')
-    add_to_logging('optimizer-copy-main-to-model-params')
-    add_to_logging('optimizer')
-    add_to_logging('batch-generator')
+    # FastMoE registers additional timers of its own, so instead of
+    # maintaining a fixed list of names, log every timer that exists.
+    def add_all():
+        """Queue every known timer name for logging."""
+        timers_to_log.extend(timers.timers)
+
+    add_all()
 
     # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \
-- 
2.25.1

